Skip to content

ext/Mooncake: handle ComponentArray cotangents at @from_rrule boundaries#352

Merged
ChrisRackauckas merged 1 commit intoSciML:mainfrom
ChrisRackauckas-Claude:mooncake-ext-widen-dispatch
Apr 11, 2026
Merged

ext/Mooncake: handle ComponentArray cotangents at @from_rrule boundaries#352
ChrisRackauckas merged 1 commit intoSciML:mainfrom
ChrisRackauckas-Claude:mooncake-ext-widen-dispatch

Conversation

@ChrisRackauckas-Claude
Copy link
Copy Markdown
Contributor

Summary

  • Widen ComponentArraysMooncakeExt.increment_and_get_rdata! to cover the cotangent shapes that real ChainRulesCore.rrules produce for ComponentArray primals, so downstream packages that declare an @from_rrule / @from_chainrules boundary with a ComponentArray argument actually work.
  • Add Mooncake to test/autodiff with a focused @testset that exercises both native Mooncake (prepare_gradient_cache / value_and_gradient!! over nested ComponentArray(; u0, p_all)) and the @from_rrule round-trip that the new dispatch targets.

Motivation

SciML/SciMLSensitivity.jl#1419 migrates the docs tutorials from Zygote to Mooncake, but several of them (flagged as !!! note in that PR — second_order_neural.md, brusselator.md, feedback_control.md, …) have to stay on Zygote because they hit

ArgumentError: The fdata type Mooncake.FData{@NamedTuple{data::Vector{Float64}, axes::Mooncake.NoFData}},
rdata type Mooncake.NoRData, and tangent type ComponentVector{Float64, Vector{Float64}, …} combination
is not supported with @from_chainrules or @from_rrule. …

The root cause is that the existing extension method

function Mooncake.increment_and_get_rdata!(
        f::Mooncake.FData{@NamedTuple{data::A, axes::Mooncake.NoFData}},
        r::Mooncake.NoRData,
        t::A,
    ) where {P <: Union{IEEEFloat, Complex{<:IEEEFloat}}, A <: Array{P}}

only matches when the ChainRules tangent t is a raw Array{P}. In practice the rrules defined in src/compat/chainrulescore.jl (for getproperty, getdata, Type{ComponentArray}(data, axes), Type{CA}(nt::NamedTuple)) return cotangents that are themselves ComponentArrays (either flat-Array-backed or, when a view is involved, SubArray-backed). Any SciML package that declares a Mooncake primitive with a ComponentVector argument therefore funnels through increment_and_get_rdata! with a ComponentArray tangent and hits the fallback error above.

What this PR adds

Three additional methods on Mooncake.increment_and_get_rdata!:

  1. Flat-Array-backed CV fdata + ComponentArray cotangent. Unwrap via getdata(t) and delegate to the underlying storage. This is the common case — a loss function takes a ComponentVector{Float64, Vector{Float64}}, the rrule returns ComponentArray(Δ, getaxes(x)).
  2. SubArray-backed CV fdata + Array cotangent. Produced whenever getproperty(::ComponentVector, ::Symbol) or any other view-producing operation crosses an @from_rrule boundary. We aggregate into the parent-array slot of the SubArray's structural tangent for the full-parent-coverage case, which is what actually lands at these boundaries in practice.
  3. SubArray-backed CV fdata + ComponentArray cotangent. Same as (2), but first getdata(t).

Cases (2) and (3) raise a clear ArgumentError for the partial-view case (where the view's linear indices can't be recovered from fdata alone), so we never silently misplace gradient mass. Opening an issue with a reproducer is straightforward if anyone hits that path.

The existing raw-Array method and Mooncake.friendly_tangent_cache definition are preserved verbatim.

Tests

New @testset \"Mooncake\" in test/autodiff/autodiff_tests.jl:

  • Native Mooncake prepare_gradient_cache / value_and_gradient!! on a flat ComponentVector and on a nested ComponentArray(; u0, p_all) layout (matches the feedback_control.md shape from SciML/SciMLSensitivity.jl#1419).
  • A synthetic sum_abs2 with a hand-written ChainRulesCore.rrule whose pullback returns a ComponentArray cotangent, declared as a Mooncake primitive via @from_rrule. Two cases: a flat ComponentVector, and a nested ComponentArray(; u0 = Vector, p_all = ComponentArray). Both paths fail on main and pass after this patch.
  • A smoke check that Mooncake.friendly_tangent_cache(::ComponentArray) still returns a FriendlyTangentCache{AsPrimal}.

test/autodiff/Project.toml grows Mooncake = \"0.5.26\" and ChainRulesCore = \"1\". The 0.5.26 pin matches the friendly_tangent_cache symbol the extension already references (it doesn't exist in 0.5.24 and earlier — that precompile failure is how I discovered the existing extension has an implicit floor that the main Project.toml's Mooncake = \"0.5\" doesn't encode; worth tightening in a follow-up but orthogonal to this PR).

Local results on Julia 1.12:

  • GROUP=Autodiff56/56 pass (49 prior + 7 new asserts), 3m24s
  • GROUP=Core459 pass / 9 pre-existing broken, 1m45s

Not in scope

The MooncakeRuleCompilationError mentioned in the feedback_control.md / brusselator.md notes of SciML/SciMLSensitivity.jl#1419 is a compile-time failure inside Mooncake's rule builder rather than an increment_and_get_rdata! dispatch gap — I couldn't reproduce it from any standalone ComponentArrays snippet, so it appears to originate in SciMLBase/SciMLSensitivity's adjoint stack rather than in ComponentArrays itself. This PR is targeted at the runtime increment_and_get_rdata! gap the notes explicitly attribute to "ComponentArrays' Mooncake extension".

Test plan

  • GROUP=Autodiff julia --project=test/autodiff test/runtests.jl — 56/56 pass
  • GROUP=Core julia --project=test test/runtests.jl — 459/459 (+9 pre-existing broken)
  • CI: full matrix

🤖 Generated with Claude Code

The existing `increment_and_get_rdata!` method only matched a raw
`Array{P}` tangent against a flat-`Array`-backed ComponentVector fdata.
In practice the tangent coming out of a `ChainRulesCore.rrule` for a
ComponentArray primal is usually *another* ComponentArray (e.g. via
`ComponentArray(Δ, getaxes(x))`), so downstream packages that declare a
`@from_rrule` / `@from_chainrules` boundary with a ComponentArray
argument hit

    ArgumentError: The fdata type ... ComponentVector{...} combination
    is not supported with @from_chainrules or @from_rrule.

This is what blocked the Mooncake migration of the SciMLSensitivity.jl
tutorials in SciML/SciMLSensitivity.jl#1419 (the `feedback_control.md`
and `second_order_neural.md` notes). Widen the dispatch to cover:

  - flat-`Array`-backed ComponentVector fdata with an incoming
    `ComponentArray` cotangent (unwrap to the underlying storage),
  - SubArray-backed ComponentVector fdata (produced by
    `getproperty(::ComponentVector, ::Symbol)`) with either an `Array`
    or a `ComponentArray` cotangent — handled for the common
    full-parent-coverage case, with a clear `ArgumentError` for the
    partial-view case that would otherwise silently misplace gradient
    mass.

Tests: exercise both native Mooncake (`prepare_gradient_cache` +
`value_and_gradient!!` over nested `ComponentArray(; u0, p_all)`) and
the `@from_rrule` round-trip path that the new methods target. Adds
Mooncake to `test/autodiff/Project.toml` (pinned to `0.5.26` to match
the `friendly_tangent_cache` symbol the extension already references).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants